{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "uhYG8yG47rKW"
      },
      "source": [
        "# License\n",
        "\n",
        "Licensed under the Apache License, Version 2.0 (the \"License\")\n",
        "```\n",
        "Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "you may not use this file except in compliance with the License.\n",
        "You may obtain a copy of the License at\n",
        "\n",
        "https://www.apache.org/licenses/LICENSE-2.0\n",
        "\n",
        "Unless required by applicable law or agreed to in writing, software\n",
        "distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "See the License for the specific language governing permissions and\n",
        "limitations under the License.\n",
        "```"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "2xOk6PugLKrJ"
      },
      "source": [
        "# Setup"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "EGt8QZhj_TXc"
      },
      "outputs": [],
      "source": [
        "# Uncomment to install the covid_vhh_design package\n",
        "\n",
        "# !pip install git+https://github.com/google-research/google-research.git#subdirectory=covid_vhh_design"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "a51RdViEciv2"
      },
      "source": [
        "# Imports"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "JAFVmV4j7-hX"
      },
      "outputs": [],
      "source": [
        "from IPython.display import display\n",
        "from matplotlib import pyplot as plt\n",
        "import numpy as np\n",
        "import pandas as pd\n",
        "import seaborn as sns"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "C2xfG8K58fou"
      },
      "outputs": [],
      "source": [
        "from covid_vhh_design import covid\n",
        "from covid_vhh_design import helper\n",
        "from covid_vhh_design import models"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Ct5kxN1dH0iV"
      },
      "outputs": [],
      "source": [
        "# Render inline matplotlib figures at high (retina) resolution.\n",
        "%config InlineBackend.figure_format = 'retina'\n",
        "\n",
        "# Widen the pandas display so long sequences and wide score tables are shown\n",
        "# in full: 200-character lines, untruncated cell contents (`None` disables the\n",
        "# column-width limit), and up to 200 rows before the output is elided.\n",
        "pd.set_option('display.width', 200)\n",
        "pd.set_option('display.max_colwidth', None)\n",
        "pd.set_option('display.max_rows', 200)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "t-c3-NZRLCYg"
      },
      "source": [
        "# Load the model"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "WxTOe9FILPyj"
      },
      "outputs": [],
      "source": [
        "# Single namespace object for notebook-wide state. Later cells attach the\n",
        "# random state, model, encoder, input sequences, and prediction tables to it\n",
        "# as attributes (`G.<name>`).\n",
        "G = helper.Bunch()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "9hctSrmNLS5e"
      },
      "outputs": [],
      "source": [
        "# Seeded random state; the baseline sequences generated below draw from it.\n",
        "G.random_state = np.random.RandomState(0)\n",
        "\n",
        "# The combined regressor/classifier model. Its `regressor` and `classifier`\n",
        "# parts are also used on their own further down.\n",
        "# NOTE(review): `load()` is called without arguments, so it presumably falls\n",
        "# back to a default model location -- confirm in `models.CombinedModel`.\n",
        "G.model = models.CombinedModel.load()\n",
        "\n",
        "# Encoder to onehot- and AAIndex-encode amino acid sequences\n",
        "G.encoder = models.SequenceEncoder()\n",
        "\n",
        "# BLI dataset; the `source_seq` of the row labeled `Seq1` is scored below.\n",
        "G.bli = covid.load_df('bli_v2.csv')"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "US-lpY80vjuc"
      },
      "source": [
        "# Input sequences\n",
        "Input sequences must be 125 amino acids long (the length of VHH-72). Only the 20 natural amino acids are allowed; gaps and special amino acids are not.\n",
        "\n",
        "As an example, we make predictions for the parent sequence VHH-72, the best designed sequence (the one with the lowest BLI KD value), and several baseline sequences obtained by randomly mutating the parent sequence, shuffling it, or sampling amino acids at random."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "KRw_766P_rKN"
      },
      "outputs": [],
      "source": [
        "def shuffle_sequence(sequence: str, random_state: np.random.RandomState) -\u003e str:\n",
        "  \"\"\"Returns `sequence` with its characters rearranged in a random order.\n",
        "\n",
        "  Args:\n",
        "    sequence: The string to shuffle.\n",
        "    random_state: Source of randomness; one permutation is drawn from it.\n",
        "\n",
        "  Returns:\n",
        "    A string of the same length containing the same characters as `sequence`.\n",
        "  \"\"\"\n",
        "  order = random_state.permutation(len(sequence))\n",
        "  return ''.join([sequence[index] for index in order])\n",
        "\n",
        "\n",
        "def get_random_sequence(\n",
        "    length: int, random_state: np.random.RandomState\n",
        ") -\u003e str:\n",
        "  \"\"\"Samples a sequence of `length` amino acids, with replacement.\n",
        "\n",
        "  Args:\n",
        "    length: Number of amino acids to draw.\n",
        "    random_state: Source of randomness.\n",
        "\n",
        "  Returns:\n",
        "    The drawn entries of `models.AMINO_ACIDS`, concatenated into one string.\n",
        "  \"\"\"\n",
        "  residues = random_state.choice(models.AMINO_ACIDS, size=length, replace=True)\n",
        "  return ''.join(residues)\n",
        "\n",
        "\n",
        "def mutate_sequence(\n",
        "    sequence: str, num_mutations: int, random_state: np.random.RandomState\n",
        ") -\u003e str:\n",
        "  \"\"\"Substitutes `num_mutations` distinct, randomly chosen positions.\n",
        "\n",
        "  Each selected position receives an entry of `models.AMINO_ACIDS` that\n",
        "  differs from the residue originally at that position.\n",
        "\n",
        "  Args:\n",
        "    sequence: The sequence to mutate.\n",
        "    num_mutations: Number of positions to substitute.\n",
        "    random_state: Source of randomness for both the positions and the\n",
        "      replacement residues.\n",
        "\n",
        "  Returns:\n",
        "    The mutated sequence; it has the same length as `sequence`.\n",
        "\n",
        "  Raises:\n",
        "    ValueError: If `num_mutations` exceeds `len(sequence)` (raised by\n",
        "      `random_state.choice` when sampling positions without replacement).\n",
        "  \"\"\"\n",
        "  alphabet = set(models.AMINO_ACIDS)\n",
        "  mutant = list(sequence)\n",
        "  for pos in random_state.choice(len(sequence), num_mutations, replace=False):\n",
        "    # Sort the candidates: the iteration order of a set of strings changes\n",
        "    # between interpreter runs (hash randomization), which would make the\n",
        "    # chosen residue irreproducible even with a seeded `random_state`.\n",
        "    candidates = sorted(alphabet - {sequence[pos]})\n",
        "    mutant[pos] = random_state.choice(candidates)\n",
        "  return ''.join(mutant)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "acmCmOytsNJQ"
      },
      "outputs": [],
      "source": [
        "# Labeled sequences to score. Entries are evaluated in the order listed, which\n",
        "# fixes the order in which draws are taken from `G.random_state`.\n",
        "G.seqs = {\n",
        "    # Best BLI sequence\n",
        "    'best': G.bli.query('label == \"Seq1\"')['source_seq'].iloc[0],\n",
        "    # Parent sequence\n",
        "    'parent': covid.PARENT_SEQ,\n",
        "    # Parent with one, two, and three random substitutions\n",
        "    'mutant1': mutate_sequence(covid.PARENT_SEQ, 1, G.random_state),\n",
        "    'mutant2': mutate_sequence(covid.PARENT_SEQ, 2, G.random_state),\n",
        "    'mutant3': mutate_sequence(covid.PARENT_SEQ, 3, G.random_state),\n",
        "    # Parent with its residues shuffled\n",
        "    'shuffled': shuffle_sequence(covid.PARENT_SEQ, G.random_state),\n",
        "    # Randomly sampled amino acids, same length as the parent\n",
        "    'random': get_random_sequence(len(covid.PARENT_SEQ), G.random_state),\n",
        "}\n",
        "G.seqs"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "l4jpOtVcqiYC"
      },
      "source": [
        "# Making predictions\n",
        "Predictions can be made with `models.score_labeled_sequences`, which calls `encoder.encode_sequences(sequences)` to encode the sequences and `model.predict` to predict binding scores, and returns a `DataFrame` with the predictions.\n",
        "\n",
        "We use\n",
        "* `G.model.regressor.predict`: to predict normalized binding scores with the regressor. A binding score is a positive float, where greater values indicate stronger binding. Specifically, predicted scores correspond to inverted AlphaSeq log KD values (score = 5.073646 - log KD), where 5.073646 is the maximum normalized log KD value in the training dataset.\n",
        "* `G.model.classifier.predict`: to predict binding probabilities between 0 (no binding) and 1 (binding) with the classifier.\n",
        "* `G.model.predict`: to make predictions with the combined regressor/classifier model. Output values correspond to `G.model.regressor.predict(x) * G.model.classifier.predict(x)`.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "14Iv7p2DvIlg"
      },
      "outputs": [],
      "source": [
        "def plot_scores(scores: pd.DataFrame) -\u003e None:\n",
        "  \"\"\"Draws a grouped bar chart of model predictions.\n",
        "\n",
        "  Args:\n",
        "    scores: Wide table of predictions. Resetting its index must yield a\n",
        "      `label` column, which is used to color the bars; every other column is\n",
        "      treated as a target and becomes one group on the x-axis.\n",
        "  \"\"\"\n",
        "  long_scores = pd.melt(\n",
        "      scores.reset_index(),\n",
        "      id_vars='label',\n",
        "      var_name='target_name',\n",
        "      value_name='value',\n",
        "  )\n",
        "  _, ax = plt.subplots(figsize=(15, 5))\n",
        "  sns.barplot(\n",
        "      data=long_scores,\n",
        "      x='target_name',\n",
        "      y='value',\n",
        "      hue='label',\n",
        "      palette='tab10',\n",
        "      ax=ax,\n",
        "  )\n",
        "  ax.set(xlabel='', ylabel='score')\n",
        "  # Render once so the tick labels are populated before they are re-applied\n",
        "  # with a rotation.\n",
        "  ax.figure.canvas.draw()\n",
        "  ax.set_xticklabels(ax.get_xticklabels(), rotation=30, ha='right')\n",
        "  # Anchor the legend's upper-left corner at the axes' upper-right corner.\n",
        "  ax.legend(frameon=True, ncol=1, loc='upper left', bbox_to_anchor=(1.0, 1.0))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "PgDGt0fispse"
      },
      "outputs": [],
      "source": [
        "# Predict combined regressor/classifier scores. As described in the text\n",
        "# above, these correspond to the regressor prediction multiplied by the\n",
        "# classifier prediction.\n",
        "G.scores = models.score_labeled_sequences(G.model, G.encoder, G.seqs)\n",
        "plot_scores(G.scores)\n",
        "display(G.scores)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "yNnbDfivwZ4q"
      },
      "outputs": [],
      "source": [
        "# Predict binding scores with the regressor alone. As described in the text\n",
        "# above, greater values indicate stronger binding.\n",
        "G.reg_scores = models.score_labeled_sequences(G.model.regressor, G.encoder, G.seqs)\n",
        "plot_scores(G.reg_scores)\n",
        "display(G.reg_scores)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "63n_syIVGjIT"
      },
      "outputs": [],
      "source": [
        "# Predict binding probabilities with the classifier alone.\n",
        "# With `proba=True` the outputs are probabilities between 0 (no binding) and\n",
        "# 1 (binding); with `proba=False` they are 0/1 values.\n",
        "G.cla_scores = models.score_labeled_sequences(G.model.classifier, G.encoder, G.seqs, proba=True)\n",
        "plot_scores(G.cla_scores)\n",
        "display(G.cla_scores)"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
